# Types accepted as a single Jacobian block inside a JacobianDict.
# Declared `const` so the alias is a compile-time constant: a non-const global is
# treated as untyped, which makes every `x isa Jacobian` check type-unstable.
const Jacobian = Union{AbstractArray, IdentityMatrix, SimpleSparse}

# `nothing` is an immutable singleton, so it is its own copy; this lets `copy(x)` be
# called on optional (`Union{Nothing, ...}`) values.
# NOTE(review): this extends a Base function for a Base type (type piracy); confirm it is still needed.
Base.copy(x::Nothing) = x

"""
    NestedDict{R, S}

Abstract two-level dictionary keyed first by output name and then by input name.
Subtypes are expected to provide the fields `nesteddict` (output => inner dict),
`outputs` and `inputs`, which the generic methods below read.
"""
abstract type NestedDict{R, S} <: AbstractDict{R, S} end

# Key/value type parameters of the dictionary.
function Base.keytype(::NestedDict{K, V}) where {K, V}
    return K
end

function Base.valtype(::NestedDict{K, V}) where {K, V}
    return V
end

# One-line summary shared by the plain and text/plain `show` methods.
_nesteddict_repr(nd::NestedDict) = string("<", typeof(nd), " outputs=", nd.outputs, ", inputs=", nd.inputs, ">")

Base.show(io::IO, nd::NestedDict) = print(io, _nesteddict_repr(nd))
Base.show(io::IO, ::MIME"text/plain", nd::NestedDict) = print(io, _nesteddict_repr(nd))

# Iteration walks `nd.outputs`, yielding output names.
# NOTE(review): AbstractDict iteration conventionally yields `key => value` pairs, and
# the `length` method for NestedDict counts `nesteddict` entries. Generic consumers
# such as `collect(nd)` may therefore misbehave; confirm this is intended.
function Base.iterate(nd::NestedDict)
    return iterate(nd.outputs)
end

function Base.iterate(nd::NestedDict, state)
    return iterate(nd.outputs, state)
end

"""
    union(nd::NestedDict, s)

Return a new NestedDict holding the entries of `nd` merged with `s` via `merge!`.
`nd` itself is left untouched, because `merge!` rebinds the merged object's fields
rather than mutating the containers it shares with `nd`.
"""
function Base.union(nd::NestedDict, s)
    # Rebuild through the unparameterized type (e.g. `JacobianDict`) so that {R, S} are
    # re-derived from `nd.nesteddict`. After a previous `merge!` the stored dict's type can
    # differ from `nd`'s parameters, and `typeof(nd)(...)` would then throw a MethodError.
    ctor = Base.typename(typeof(nd)).wrapper
    merged = ctor(nd.nesteddict; outputs=nd.outputs, inputs=nd.inputs)
    merge!(merged, s)
    return merged
end

# Look up a single output name: returns that output's inner dict (input name => block).
Base.getindex(nd::NestedDict, k::AbstractString) = nd.nesteddict[k]

"""
    getindex(nd::NestedDict, (o, i)::Tuple)

Index by an `(outputs, inputs)` pair. Either part may be a single name, a collection
of names, or `:` meaning all of `nd.outputs` / `nd.inputs`.

- name, name: the single block `nd.nesteddict[o][i]`.
- name, collection: a plain `Dict` of that output's blocks restricted to `i`.
- collection, name-or-collection: a new NestedDict of the same kind restricted to
  those outputs and inputs.
"""
function Base.getindex(nd::NestedDict, k::Tuple)
    o, i = k
    o = o isa Colon ? nd.outputs : o
    i = i isa Colon ? nd.inputs : i

    if o isa AbstractString
        if i isa AbstractString
            return nd.nesteddict[o][i]
        else
            return subdict(nd.nesteddict[o], i)
        end
    else
        # Wrap a single input name in a Vector, not a Tuple, so it can pass through
        # constructors that `copy` their `inputs` argument.
        i = i isa AbstractString ? [i] : i
        # Resolve the constructor from the type itself; `getfield(Main, ...)` fails
        # whenever the type is defined inside a module other than Main.
        ctor = Base.typename(typeof(nd)).wrapper
        return ctor(Dict(oo => subdict(nd.nesteddict[oo], i) for oo ∈ o); outputs=o, inputs=i)
    end
end

"""
    getindex(nd::NestedDict, k::Union{OrderedSet, Set, Vector})

Restrict `nd` to the output names in `k`, keeping all of `nd.inputs`.
"""
function Base.getindex(nd::NestedDict, k::Union{OrderedSet, Set, Vector})
    # See the tuple method: avoid looking the type up by name in Main.
    ctor = Base.typename(typeof(nd)).wrapper
    return ctor(Dict(oo => nd.nesteddict[oo] for oo ∈ k); outputs=k, inputs=nd.inputs)
end

# Forward `get(nd, key, default)` (and any other `get` form) to the output-level dict.
function Base.get(nd::NestedDict, args...; kwargs...)
    return get(nd.nesteddict, args...; kwargs...)
end

# TODO(review): a forwarding `Base.setindex!(nd::NestedDict, v, k)` that writes into
# `nd.nesteddict` was considered but is left disabled.

# Number of output rows stored in `nesteddict`. This can differ from
# `length(nd.outputs)` when explicit outputs were supplied to the constructor.
function Base.length(nd::NestedDict)
    return length(nd.nesteddict)
end

# Key/value/pair views forward to the output-level dict.
# TODO(review): these were marked "MAYBE"; confirm they belong in the interface.
function Base.keys(nd::NestedDict)
    return keys(nd.nesteddict)
end

function Base.values(nd::NestedDict)
    return values(nd.nesteddict)
end

function Base.pairs(nd::NestedDict)
    return pairs(nd.nesteddict)
end

"""
    copy(nd::NestedDict)

Return a shallow copy of `nd`. The output-level dict and the `outputs`/`inputs`
collections are fresh containers; the per-output inner dicts and the Jacobian
blocks are still shared with `nd`.
"""
function Base.copy(nd::NestedDict)
    dup = typeof(nd)(nd)
    # JacobianDict's NestedDict constructor reuses nd's containers as-is; replace them so
    # that mutating the copy (e.g. `dup.nesteddict[k] = v`) does not leak back into `nd`.
    # Assumes a mutable subtype, as `merge!` already does.
    dup.nesteddict = copy(nd.nesteddict)
    dup.outputs = copy(nd.outputs)
    dup.inputs = copy(nd.inputs)
    return dup
end

"""
    merge!(nd::NestedDict, J::NestedDict)

Add the output rows of `J` to `nd` in place and return `nd`.

An empty `J` (no outputs or no inputs) is a no-op. Throws an `ErrorException` if
the input sets of `nd` and `J` differ, or if they share any output name.
"""
function Base.merge!(nd::NestedDict, J::NestedDict)
    (isempty(J.outputs) || isempty(J.inputs)) && return nd
    if Set(nd.inputs) != Set(J.inputs)
        error("Cannot merge $(typeof(nd)) with non-overlapping inputs $(symdiff(Set(nd.inputs), Set(J.inputs)))")
    end
    overlap = Set(nd.outputs) ∩ Set(J.outputs)
    if !isempty(overlap)
        error("Cannot merge $(typeof(nd)) with overlapping outputs $(overlap)")
    end
    # Rebind (rather than mutate) the containers so that objects still sharing nd's old
    # dict/outputs, such as the source of a `union`, are left untouched.
    nd.outputs = nd.outputs ∪ J.outputs
    nd.nesteddict = merge(nd.nesteddict, J.nesteddict)
    return nd
end


"""
    complete(nd::NestedDict, filler)

Return a new NestedDict with the same outputs and inputs as `nd`, in which every
(output, input) pair has an entry. Pairs missing from `nd`, including whole output
rows, are set to `filler`. `nd` is not modified, because each row is copied first.
"""
function complete(nd::NestedDict, filler)
    nesteddict = Dict()
    for o ∈ nd.outputs
        # Untyped copy of the row, so `filler` need not match the stored block type.
        # An output without a row starts empty.
        row = Dict{Any, Any}(get(nd.nesteddict, o, ()))
        for i ∈ nd.inputs
            # Use `haskey`, not `i ∈ row`: `in` on a Dict tests for Pairs and throws for plain keys.
            haskey(row, i) || (row[i] = filler)
        end
        nesteddict[o] = row
    end
    # Use the unparameterized constructor: the rebuilt Dict{Any, Any} generally does not
    # match `nd`'s {R, S} parameters, so `typeof(nd)(...)` would throw a MethodError.
    ctor = Base.typename(typeof(nd)).wrapper
    return ctor(nesteddict; outputs=nd.outputs, inputs=nd.inputs)
end

# Drop repeated elements of `v`, keeping the first occurrence of each in order.
function deduplicate(v::AbstractVector)
    return unique(v)
end

# Restrict dict `d` to the keys in `ks` that it actually contains; absent keys are skipped.
function subdict(d, ks)
    present = Iterators.filter(k -> haskey(d, k), ks)
    return Dict(k => d[k] for k in present)
end

"""
    JacobianDict{R, S}

Two-level dictionary of Jacobian blocks, where `nesteddict[output][input]` holds the
block linking `input` to `output`.

Fields
- `nesteddict`: output name => (input name => block).
- `outputs`, `inputs`: `OrderedSet`s fixing the block-row / block-column order.
- `name`: label, `"NestedDict"` unless given.
- `T`: optional block dimension used when packing (`nothing` if unknown).

If either `outputs` or `inputs` ends up empty, both are reset to empty.
"""
mutable struct JacobianDict{R, S} <: NestedDict{R, S}
    nesteddict
    outputs
    inputs
    name
    T

    # Re-wrap an existing NestedDict, sharing its containers. Only `check` is honored;
    # the other keywords are accepted for signature compatibility but ignored.
    function JacobianDict{R, S}(nd::NestedDict{R, S}; outputs=nothing, inputs=nothing, name=nothing, T=nothing, check=false) where {R, S}
        check && ensure_valid_jacobiandict(nd)
        return new{R, S}(nd.nesteddict, nd.outputs, nd.inputs, nd.name, nd.T)
    end

    function JacobianDict{R, S}(nd::AbstractDict{R, S}; outputs=nothing, inputs=nothing, name=nothing, T=nothing, check=false) where {R, S}
        check && ensure_valid_jacobiandict(nd)

        # Outputs default to the outer keys. `OrderedSet(...)` always builds a fresh
        # container, so no separate `copy` is needed, and any iterable (including a
        # Tuple) is accepted.
        _outputs = outputs === nothing ? OrderedSet(keys(nd)) : OrderedSet(outputs)

        # Inputs default to the union of the inner keys of every row. Non-Dict row
        # values contribute their elements (arrays) or themselves.
        if inputs === nothing
            _inputs = OrderedSet()
            for v ∈ values(nd)
                if v isa Dict
                    union!(_inputs, keys(v))
                elseif v isa AbstractArray
                    union!(_inputs, v)
                else
                    push!(_inputs, v)
                end
            end
        else
            _inputs = OrderedSet(inputs)
        end

        # Normalize: with no outputs or no inputs, the dict is treated as fully empty.
        if isempty(_outputs) || isempty(_inputs)
            _outputs = OrderedSet()
            _inputs = OrderedSet()
        end

        _name = name === nothing ? "NestedDict" : name
        return new{R, S}(nd, _outputs, _inputs, _name, T)
    end
end

# Outer constructor: infer {R, S} from the given dict and forward every keyword unchanged.
function JacobianDict(nd::AbstractDict{R, S}; outputs=nothing, inputs=nothing, name=nothing, T=nothing, check=false) where {R, S}
    return JacobianDict{R, S}(nd;
                              outputs=outputs,
                              inputs=inputs,
                              name=name,
                              T=T,
                              check=check)
end

"""
    identity(ks)

Build a JacobianDict whose only blocks are `ks[k][k] = IdentityMatrix()` for each
name `k` in `ks`, with `ks` as both outputs and inputs.

NOTE(review): this unqualified definition creates a module-level `identity` that
shadows (or conflicts with) `Base.identity`. Any one-argument `identity(x)` call in
this module now builds a JacobianDict; consider renaming.
"""
function identity(ks)
    blocks = Dict{Any, Any}()
    for k ∈ ks
        blocks[k] = Dict{Any, Any}(k => IdentityMatrix())
    end
    return JacobianDict(blocks; outputs=ks, inputs=ks)
end

"""
    addinputs(J::JacobianDict)

Return `J` extended with an identity block for every input that is not already an
output, so those inputs also appear as outputs mapping to themselves.
"""
function addinputs(J::JacobianDict)
    passthrough = [x for x ∈ J.inputs if !(x ∈ J.outputs)]
    identity_block = identity(passthrough)
    return union(J, identity_block)
end

# Multiplication dispatch for JacobianDict:
#   JacobianDict * JacobianDict             -> `compose`
#   JacobianDict * Bijection, Bijection * J -> `remap` (renaming)
#   JacobianDict * anything else            -> `apply`
function Base.:*(J::JacobianDict, K::JacobianDict)
    return compose(J, K)
end

function Base.:*(J::JacobianDict, x::Bijection)
    return remap(J, x)
end

function Base.:*(x::Bijection, J::JacobianDict)
    return remap(J, x)
end

function Base.:*(J::JacobianDict, x)
    return apply(J, x)
end

"""
    remap(J::JacobianDict, x::Bijection)

Rename the outputs and inputs of `J` through the bijection `x`, at both the outer
and inner dict level. Returns `J` itself when `x` is empty.
"""
function remap(J::JacobianDict, x::Bijection)
    isempty(x) && return J
    # NOTE(review): assumes `x * J.nesteddict` returns a new dict. The loop below
    # reassigns its entries, which would otherwise modify J.
    nesteddict = x * J.nesteddict
    for o ∈ keys(nesteddict)
        nesteddict[o] = x * nesteddict[o]
    end
    # Renaming does not change the block dimension, so carry `T` (and `name`) over;
    # dropping them forces callers to re-supply T to `pack` / `FactoredJacobianDict`.
    return JacobianDict(nesteddict; inputs=x * J.inputs, outputs=x * J.outputs, name=J.name, T=J.T)
end

# A JacobianDict is empty when it has no outputs or no inputs.
function Base.isempty(J::JacobianDict)
    isempty(J.outputs) && return true
    return isempty(J.inputs)
end

"""
    compose(J::JacobianDict, K::JacobianDict)

Return the JacobianDict of `J * K`. For every output `o` of `J` and input `i` of `K`,
it sums `J[o][m] * K[m][i]` over the intermediate names `m` that are both inputs of
`J` and outputs of `K`. Blocks with no contributing `m` are omitted.

The result has outputs `J.outputs`, inputs `K.inputs`, and whichever of `J.T`/`K.T`
is known. Throws an error if both `T`s are known and differ.
"""
function compose(J::JacobianDict, K::JacobianDict)
    if !(J.T isa Nothing) && !(K.T isa Nothing) && (J.T != K.T)
        error("Trying to multiply JacobianDicts with inconsistent dimensions $(J.T) and $(K.T)")
    end
    T = J.T isa Nothing ? K.T : J.T

    o_list = J.outputs
    # Intermediate names in J's input order, so the summation order is deterministic
    # (iterating a `Set` gives an arbitrary, hash-dependent order).
    k_outputs = Set(K.outputs)
    m_list = [m for m ∈ J.inputs if m ∈ k_outputs]
    i_list = K.inputs

    J_om = J.nesteddict
    J_mi = K.nesteddict
    J_oi = Dict()

    for o ∈ o_list
        J_oi[o] = Dict()
        # Rows or blocks absent from the sparse dicts contribute nothing, instead of raising a KeyError.
        J_o = get(J_om, o, nothing)
        J_o === nothing && continue
        for i ∈ i_list
            Jout = nothing
            for m ∈ m_list
                haskey(J_o, m) || continue
                J_m = get(J_mi, m, nothing)
                (J_m === nothing || !haskey(J_m, i)) && continue
                term = J_o[m] * J_m[i]
                Jout = Jout === nothing ? term : term + Jout
            end
            Jout === nothing || (J_oi[o][i] = Jout)
        end
    end

    return JacobianDict(J_oi; outputs=o_list, inputs=i_list, T=T)
end

"""
    apply(J::JacobianDict, x)

Apply `J` to the impulses in `x`. For each output `o` of `J`, it computes
`sum(J[o][i] * x[i])` over the inputs `i` present in both `x` and `J`, starting from
`zeros(x.T)`. Returns `x` combined (via `∪`) with an ImpulseDict of these results.
"""
function apply(J::JacobianDict, x::Union{ImpulseDict, Dict{<:AbstractString, <:AbstractArray}})
    x = ImpulseDict(x)
    # `Set`, not Python's `set`, which is undefined in Julia.
    inputs = keys(x) ∩ Set(J.inputs)
    J_oi = J.nesteddict
    y = Dict()

    for o ∈ J.outputs
        y[o] = zeros(x.T)
        # Outputs without a stored row stay at zero.
        J_i = get(J_oi, o, nothing)
        J_i === nothing && continue
        for i ∈ inputs
            # `haskey`: `i ∈ J_i` on a Dict throws unless `i` is a Pair.
            if haskey(J_i, i)
                y[o] += J_i[i] * x[i]
            end
        end
    end
    # Same dimension attribute as used for `zeros(x.T)` above (was the non-existent `x.t`).
    return x ∪ ImpulseDict(y; T=x.T)
end

"""
    pack(J::JacobianDict; T=nothing)

Assemble `J` into a dense matrix of size
`(length(J.outputs) * T, length(J.inputs) * T)`. Block `(o, i)` occupies the
`io`-th block row and `ii`-th block column, in the order of `J.outputs` and `J.inputs`.
Missing blocks and rows are zero.

`T` defaults to `J.T`. An error is thrown if neither is known, or if both are known
and differ.
"""
function pack(J::JacobianDict; T=nothing)
    if T isa Nothing
        J.T isa Nothing && error("Trying to pack $J into matrix, but do not know T")
        T = J.T
    elseif !(J.T isa Nothing) && (T != J.T)
        error("JacobianDict has dimension $(J.T), but trying to pack it with alternate dimension $T")
    end

    # `zeros` already provides the zero fill for every block that is not written below.
    K = zeros(length(J.outputs) * T, length(J.inputs) * T)
    for (io, o) ∈ enumerate(J.outputs)
        row = get(J.nesteddict, o, nothing)
        row === nothing && continue
        rows = T*(io-1)+1:T*io
        for (ii, i) ∈ enumerate(J.inputs)
            J_oi = get(row, i, nothing)
            J_oi === nothing && continue
            K[rows, T*(ii-1)+1:T*ii] = make_matrix(J_oi, T)
        end
    end
    return K
end

"""
    unpack(bigjac, outputs, inputs, T)

Slice the block matrix `bigjac` into `T`-by-`T` blocks, using the same layout as
`pack`: block row `k` of `outputs` and block column `k` of `inputs`. Returns them as a
JacobianDict with the given outputs, inputs and `T`.
"""
function unpack(bigjac, outputs, inputs, T)
    block(k) = T*(k-1)+1:T*k
    jacdict = Dict()
    for (row, o) ∈ enumerate(outputs)
        jacdict[o] = Dict{Any, Any}(i => bigjac[block(row), block(col)] for (col, i) ∈ enumerate(inputs))
    end
    return JacobianDict(jacdict; outputs=outputs, inputs=inputs, T=T)
end


"""
    FactoredJacobianDict(jacobian_dict::JacobianDict; T=nothing)

LU factorization of a square JacobianDict, packed at block dimension `T`. `T` is taken
from `jacobian_dict.T` when known; otherwise the keyword `T` is required.

Fields
- `H_U_factored`: `LinearAlgebra.lu` of the packed matrix.
- `targets`: outputs of the factored JacobianDict.
- `unknowns`: inputs of the factored JacobianDict.
- `T`: block dimension.

The struct is mutable so that `remap` can copy it and then reassign `targets` and `unknowns`.
"""
mutable struct FactoredJacobianDict
    H_U_factored
    targets
    unknowns
    T

    function FactoredJacobianDict(jacobian_dict::JacobianDict; T=nothing)
        if jacobian_dict.T isa Nothing
            T isa Nothing && error("Trying to factor (solve) $(jacobian_dict) but do not know T")
            _T = T
        else
            _T = jacobian_dict.T
        end

        _targets = jacobian_dict.outputs
        _unknowns = jacobian_dict.inputs
        # Check squareness before packing; the message must use the local names.
        if length(_targets) != length(_unknowns)
            error("Trying to factor JacobianDict unequal number of inputs (unknowns) $(_unknowns) and outputs (targets) $(_targets)")
        end

        H_U = pack(jacobian_dict; T=_T)
        #TODO maybe use misc function
        return new(LinearAlgebra.lu(H_U), _targets, _unknowns, _T)
    end

    # Field-wise constructor (no refactorization), used by `copy`.
    FactoredJacobianDict(H_U_factored, targets, unknowns, T) = new(H_U_factored, targets, unknowns, T)
end

# Shallow copy that shares the (read-only) factorization.
Base.copy(fjd::FactoredJacobianDict) = FactoredJacobianDict(fjd.H_U_factored, fjd.targets, fjd.unknowns, fjd.T)

# Compact one-line display listing the unknowns and targets.
function Base.show(io::IO, fjd::FactoredJacobianDict)
    label = nameof(typeof(fjd))
    print(io, "<$(label) unknowns=$(fjd.unknowns), targets=$(fjd.targets)>")
end

"""
    to_jacobian_dict(fjd::FactoredJacobianDict)

Return the negated inverse of the factored matrix as a JacobianDict. Block rows are
`fjd.unknowns`, block columns are `fjd.targets`, and blocks are `fjd.T`-by-`fjd.T`.
"""
function to_jacobian_dict(fjd::FactoredJacobianDict)
    n = fjd.T * length(fjd.unknowns)
    # Solve against a full dense identity through the LU factorization; `\` on the
    # factorization applies the row permutation and both triangular solves. Indexing the
    # `I(n)` Diagonal with a permutation vector would use linear indexing and yield a
    # length-n vector instead of a permuted identity matrix.
    solved = fjd.H_U_factored \ Matrix{eltype(fjd.H_U_factored)}(I, n, n)
    return unpack(-solved, fjd.unknowns, fjd.targets, fjd.T)
end

# Multiplication dispatch for FactoredJacobianDict:
#   fjd * JacobianDict                  -> `compose`
#   fjd * Bijection, Bijection * fjd    -> `remap` (renaming)
#   fjd * anything else                 -> `apply`
function Base.:*(fjd::FactoredJacobianDict, J::JacobianDict)
    return compose(fjd, J)
end

function Base.:*(fjd::FactoredJacobianDict, x::Bijection)
    return remap(fjd, x)
end

function Base.:*(x::Bijection, fjd::FactoredJacobianDict)
    return remap(fjd, x)
end

function Base.:*(fjd::FactoredJacobianDict, x)
    return apply(fjd, x)
end

"""
    remap(fjd::FactoredJacobianDict, x::Bijection)

Return a copy of `fjd` whose `unknowns` and `targets` are renamed through `x`; the
factorization itself is reused. Returns `fjd` itself when `x` is empty.

NOTE(review): relies on a `copy` method for FactoredJacobianDict and on its fields
being assignable (a mutable struct); confirm both hold.
"""
function remap(fjd::FactoredJacobianDict, x::Bijection)
    if isempty(x)
        return fjd
    end
    renamed = copy(fjd)
    renamed.unknowns = x * fjd.unknowns
    renamed.targets = x * fjd.targets
    return renamed
end

"""
    compose(fjd::FactoredJacobianDict, J::JacobianDict)

Return the negated solve of the factored system against `J`, as a JacobianDict from
`J.inputs` to `fjd.unknowns`. Targets of `fjd` that are not outputs of `J` contribute
zero block rows.
"""
function compose(fjd::FactoredJacobianDict, J::JacobianDict)
    # Align J's block rows with fjd.targets. A target missing from J gets an empty row,
    # which `pack` fills with zeros, so Jsub always has the factorization's row count.
    # Dropping missing targets instead would misalign rows and break the permuted solve.
    rows = Dict(o => get(J.nesteddict, o, Dict()) for o ∈ fjd.targets)
    Jsub = pack(JacobianDict(rows; outputs=fjd.targets, inputs=J.inputs); T=fjd.T)
    # `\` on the LU factorization applies the row permutation and both triangular solves.
    out = -(fjd.H_U_factored \ Jsub)
    return unpack(out, fjd.unknowns, J.inputs, fjd.T)
end

"""
    apply(fjd::FactoredJacobianDict, x)

Pack the entries of `x` for `fjd.targets` (in target order), solve the factored
system against them, negate the result, and unpack it over `fjd.unknowns`.

NOTE(review): the 3-argument `unpack(out, unknowns, T)` is not the 4-argument
JacobianDict `unpack` (bigjac, outputs, inputs, T). It presumably targets an
ImpulseDict method defined elsewhere; confirm it exists.
"""
function apply(fjd::FactoredJacobianDict, x::Union{ImpulseDict, Dict{<:AbstractString,<:AbstractArray}})
    targets = fjd.targets
    rhs = pack(get(ImpulseDict(x), targets); force_order=targets)
    lower, upper, perm = fjd.H_U_factored
    solution = -upper \ (lower \ rhs[perm])
    return unpack(solution, fjd.unknowns, fjd.T)
end

"""
    factored(J::JacobianDict; T=nothing)

Return `FactoredJacobianDict(J; T=T)`, the LU-factored form of `J`.
"""
function factored(J::JacobianDict; T=nothing)
    # `T` is a keyword of FactoredJacobianDict. There is no 2-positional-argument
    # constructor, and `None` is not defined in Julia (the empty value is `nothing`).
    return FactoredJacobianDict(J; T=T)
end

"""
    ensure_valid_jacobiandict(d)

Sanity-check a candidate argument for the JacobianDict constructor. Only the first
output key, first row, and that row's first block are inspected:

- output keys must be strings;
- rows must be dicts with string input keys;
- blocks must be `Jacobian`s, and array blocks must be square matrices.

Empty arguments and existing JacobianDicts pass unchecked. Throws an
`ErrorException` on failure and returns `nothing` otherwise.
"""
function ensure_valid_jacobiandict(d)
    (isempty(d) || d isa JacobianDict) && return nothing

    if !(first(keys(d)) isa AbstractString)
        error("The Dict argument must have keys with type String to indicate output names.")
    end

    jac_o_dict = first(values(d))
    if !(jac_o_dict isa AbstractDict)
        error("The argument $d must be of type Dict, with keys of type String and values of type Jacobian.")
    end
    isempty(jac_o_dict) && return nothing

    if !(first(keys(jac_o_dict)) isa AbstractString)
        error("The values of the Dict argument $d must be Dicts with keys of type String to indicate input names.")
    end

    jac_o_i = first(values(jac_o_dict))
    if !(jac_o_i isa Jacobian)
        error("The Dict argument's values must be Dicts with values of type Jacobian.")
    end
    # Check `ndims` first: `size(x)[2]` raises a BoundsError on a vector instead of
    # reporting the intended "not square" error.
    if jac_o_i isa AbstractArray && (ndims(jac_o_i) < 2 || size(jac_o_i, 1) != size(jac_o_i, 2))
        error("The Jacobians must be square matrices of type Jacobian.")
    end
    return nothing
end


"""
    verify_saved_jacobian(block_name, Js, outputs, inputs, T)

Return `true` if `Js[block_name]` is a JacobianDict that covers every requested
output and input and, when `T` is given, whose first stored block has last
dimension `T`.

Otherwise return `false`. A `@warn` explaining the mismatch is emitted in every
case except a missing `block_name`, which returns `false` silently.
"""
function verify_saved_jacobian(block_name, Js, outputs, inputs, T)
    (block_name ∉ keys(Js)) && return false

    J = Js[block_name]
    if !(J isa JacobianDict)
        @warn "Js[$(block_name)] is not a JacobianDict."
        return false
    end

    if Set(outputs) ⊈ Set(J.outputs)
        miss = setdiff(Set(outputs), Set(J.outputs))
        @warn "Js[$(block_name)] misses required outputs $(miss)."
        return false
    end

    if Set(inputs) ⊈ Set(J.inputs)
        miss = setdiff(Set(inputs), Set(J.inputs))
        @warn "Js[$(block_name)] misses required inputs $(miss)."
        return false
    end

    # Only the first (output, input) block is inspected. `first` works on any ordered
    # collection (the constructor stores OrderedSets), whereas integer indexing is not
    # part of the AbstractSet interface. An empty J has no block to inspect.
    if !(T isa Nothing) && !isempty(J.outputs) && !isempty(J.inputs)
        Tsaved = size(J[first(J.outputs)][first(J.inputs)])[end]
        if T != Tsaved
            @warn "Js[$(block_name)] has length $(Tsaved), but you asked for $(T)"
            return false
        end
    end

    return true
end
